package cmd

import (
	"fmt"
	"go/importer"
	"go/token"
	"go/types"
	"io/ioutil"
	"log"
	"os"
	"strings"

	"github.com/dave/jennifer/jen"
	"github.com/influxdata/flux/codes"
	"github.com/influxdata/flux/internal/errors"
	"github.com/spf13/cobra"
)

// semanticCmd represents the "semantic" command, which generates the FromBuf methods that
// populate Go semantic graph structs from their FlatBuffers representation.
var semanticCmd = &cobra.Command{
	Use:   "semantic",
	Short: "Generate translation code from semantic FlatBuffers to Go semantic graph",
	RunE:  generateSemantic,
	Long: `Generate translation code from semantic FlatBuffers to Go semantic graph struct types.

This tool will generate a "FromBuf()" method for every struct type in semantic/graph.go that implements the Node interface,
with a few exceptions where a FlatBuffers equivalent does not exist or where deserialization is handwritten.
For each field in each such struct, code is generated to read from the FlatBuffers representation of the semantic graph
and assign the corresponding field in the Go struct.

In most cases, this was possible to do with completely generated code.  In others, the generated code will call into
handwritten code in semantic/flatbuffers.go.
`,
}

// Command-line flags of the semantic command (registered in init).
var (
	// flagOutput is the path of the file to generate; when empty, output goes to stdout.
	flagOutput string

	// flagContinueOnMissing makes unsupported field types produce "TODO" comments in the
	// generated code instead of aborting generation with an error.
	flagContinueOnMissing bool
)

// init registers the semantic command with the root command and declares its flags.
func init() {
	rootCmd.AddCommand(semanticCmd)

	flags := semanticCmd.Flags()
	flags.StringVar(&flagOutput, "output", "",
		"When present, output will be written to the specified file. Otherwise, output is piped to stdout.")
	flags.BoolVar(&flagContinueOnMissing, "continue-on-missing", false,
		`When set, "TODO" comments will be put in generated code for semantic graph nodes that are not yet handled by this tool, instead of terminating with an error.`)
}

// Import paths used by the generator and referenced from the generated code.
const (
	// semPath is the Go semantic graph package; the generated file belongs to it.
	semPath = "github.com/influxdata/flux/semantic"

	// fbsemPath is the package holding the Go FlatBuffers types for the semantic graph.
	fbsemPath = semPath + "/internal/fbsemantic"

	// errorsPath and codesPath are used by generated code to construct and wrap errors.
	errorsPath = "github.com/influxdata/flux/internal/errors"
	codesPath  = "github.com/influxdata/flux/codes"
)

// generateSemantic is the RunE function of the semantic command.
//
// It loads the Flux semantic package from source and, for every named struct type that
// implements the semantic Node interface and is not excluded by skipObject, emits a method
//
//	func (rcv *<T>) FromBuf(fb *fbsemantic.<T>) error
//
// whose body populates each field of the Go struct from the FlatBuffers value. The result
// is written to the file named by --output, or to stdout when that flag is empty.
// Errors from field generation, rendering, or writing are returned to the caller.
func generateSemantic(cmd *cobra.Command, args []string) error {
	if flagOutput != "" {
		// If there are errors in the generated file because it's currently under development,
		// replace it with something that won't trip up the call to Import() below.
		if err := ioutil.WriteFile(flagOutput, []byte("package semantic\n"), 0644); err != nil {
			return err
		}
	}
	// Fetch a representation of the Flux repo's semantic package.
	// token.NewFileSet is the documented constructor: a zero-value FileSet starts at
	// base 0, which would give the first byte of the first parsed file position NoPos.
	pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(semPath)
	if err != nil {
		log.Fatal(err)
	}

	f := jen.NewFilePath(semPath)
	f.Commentf("DO NOT EDIT.  This file was generated by the fbgen command.")
	f.Line()

	// Loop over each name in the semantic package: this includes all the types, variables,
	// constants and functions. Only struct types that implement the Node interface matter here.
	scope := pkg.Scope()
	for _, n := range scope.Names() {
		o := scope.Lookup(n)
		if skipObject(o) {
			continue
		}

		// Type of o will be a named type if we get here. Get the underlying struct type,
		// so we can loop over the fields.
		st, ok := o.Type().Underlying().(*types.Struct)
		if !ok {
			continue
		}

		// cs collects the statements of this struct type's FromBuf body. Only the capacity
		// is pre-sized: the length must start at zero because statements are appended.
		cs := make([]jen.Code, 0, st.NumFields()+3)
		// First generate some boilerplate:
		//   var err error
		//   if fb == nil {
		//     return nil
		//   }
		cs = append(cs,
			jen.Var().Err().Error(),
			jen.If(jen.Id("fb").Op("==").Nil()).Block(
				jen.Return(jen.Nil()),
			),
		)
		// Generate code that populates each field of the struct from the FlatBuffer representation.
		for i := 0; i < st.NumFields(); i++ {
			c, err := doField(o, st.Field(i))
			if err != nil {
				return err
			}
			cs = append(cs, c...)
		}
		//   return nil
		cs = append(cs, jen.Return(jen.Nil()))

		// Wrap the function body with a function definition:
		//   func (rcv *<struct name>) FromBuf(fb *fbsemantic.<struct name>) error { ... }
		f.Func().Params(jen.Id("rcv").Op("*").Id(o.Name())).
			Id("FromBuf").
			Params(jen.Id("fb").Op("*").Qual(fbsemPath, o.Name())).
			Error().
			Block(cs...)
		f.Line()
	}

	// Render and Save report rendering failures as errors, instead of going through
	// the %#v (GoString) formatting path, which has no way to return an error.
	if flagOutput == "" {
		if err := f.Render(os.Stdout); err != nil {
			return errors.Wrap(err, codes.Internal, "could not render output")
		}
		return nil
	}
	if err := f.Save(flagOutput); err != nil {
		return errors.Wrap(err, codes.Internal, "could not write output")
	}
	_, _ = fmt.Fprintf(os.Stderr, "Generated file %v\n", flagOutput)
	return nil
}

// handleMissingf reports a field type that this tool does not know how to translate.
//
// When --continue-on-missing is set, it appends a "TODO: ..." comment, built from format
// and a as with fmt.Sprintf, to cs and returns the extended slice with a nil error.
// Otherwise it returns a nil slice and an Internal error carrying the formatted message.
func handleMissingf(cs []jen.Code, format string, a ...interface{}) ([]jen.Code, error) {
	if !flagContinueOnMissing {
		return nil, errors.Newf(codes.Internal, format, a...)
	}
	todo := jen.Commentf("TODO: "+format, a...)
	return append(cs, todo), nil
}

// doField returns the Go code that populates one field of a Go semantic graph struct from
// its FlatBuffers counterpart. It only picks a generator based on the kind of the field's
// type (named, basic, pointer or slice); any other kind is reported via handleMissingf.
func doField(o types.Object, field *types.Var) ([]jen.Code, error) {
	var generate func(types.Object, *types.Var) ([]jen.Code, error)
	switch field.Type().(type) {
	case *types.Named:
		generate = doNamed
	case *types.Basic:
		generate = doBasic
	case *types.Pointer:
		generate = doPointer
	case *types.Slice:
		generate = doSlice
	default:
		return handleMissingf(make([]jen.Code, 0, 1), "unknown type kind: %#v", field.Type())
	}

	generated, err := generate(o, field)
	if err != nil {
		return nil, err
	}
	out := make([]jen.Code, 0, len(generated))
	return append(out, generated...), nil
}

// doNamed returns the Go code that populates one field of a Go semantic graph struct whose
// declared type is a named type: either an interface (e.g. Expression, PropertyKey, Node)
// or a named concrete type (e.g. loc, OperatorKind, Time). Each recognized type name maps
// to a dedicated conversion; any other name is reported via handleMissingf.
func doNamed(o types.Object, field *types.Var) ([]jen.Code, error) {
	structName, fieldName := o.Name(), field.Name()
	fieldForError := structName + "." + fieldName

	// These return a fresh statement on every call: chained jen calls extend a
	// statement in place, so a statement must not be reused once it is chained further.
	rcvField := func() *jen.Statement { return jen.Id("rcv").Dot(fieldName) }
	fbAccessor := func() *jen.Statement { return jen.Id("fb").Dot(fieldName) }

	typeName := field.Type().(*types.Named).Obj().Name()
	switch typeName {
	case "loc":
		// if fbLoc := fb.Loc(nil); fbLoc != nil { <rcv.loc.FromBuf(fbLoc), propagating errors> }
		return []jen.Code{
			jen.If(
				jen.Id("fbLoc").Op(":=").Id("fb").Dot("Loc").Params(jen.Nil()),
				jen.Id("fbLoc").Op("!=").Nil(),
			).Block(
				ifErrorPropagate(
					jen.Id("rcv").Dot("loc").Dot("FromBuf").Params(jen.Id("fbLoc")),
					fieldForError,
				),
			),
		}, nil

	case "Expression", "Assignment":
		// The handwritten helper receives the FlatBuffers accessor and the value of
		// its companion "<field>Type" accessor.
		helper := "from" + typeName + "Table"
		if isOptionalExpr(structName, fieldName) {
			helper += "Optional"
		}
		fbName := toFBName(structName, fieldName)
		return []jen.Code{
			ifErrorPropagate(
				jen.Id(helper).Params(
					jen.Id("fb").Dot(fbName),
					jen.Id("fb").Dot(fbConcat(fbName, "Type")).Params(),
				),
				fieldForError,
				rcvField(),
			),
		}, nil

	case "PropertyKey":
		// PropertyKey can only be an identifier in Rust and FB, but is an interface
		// (string literal or identifier) in Go.
		return []jen.Code{
			ifErrorPropagate(
				jen.Id("propertyKeyFromFBIdentifier").Params(fbAccessor().Params(jen.Nil())),
				fieldForError,
				rcvField(),
			),
		}, nil

	case "OperatorKind", "LogicalOperatorKind":
		helper := "fromFBOperator"
		if typeName == "LogicalOperatorKind" {
			helper = "fromFBLogicalOperator"
		}
		return []jen.Code{
			ifErrorPropagate(
				jen.Id(helper).Params(fbAccessor().Params()),
				fieldForError,
				rcvField(),
			),
		}, nil

	case "Time":
		// if fb<Field> := fb.<Field>(nil); fb<Field> != nil { rcv.<Field> = fromFBTime(fb<Field>) }
		fbVar := "fb" + fieldName
		return []jen.Code{
			jen.If(
				jen.Id(fbVar).Op(":=").Id("fb").Dot(fieldName).Params(jen.Nil()),
				jen.Id(fbVar).Op("!=").Nil(),
			).Block(
				rcvField().Op("=").Id("fromFBTime").Params(jen.Id(fbVar)),
			),
		}, nil

	case "Node":
		return []jen.Code{
			ifErrorPropagate(
				jen.Id("fromFBBlock").Params(jen.Id("fb")),
				fieldForError,
				rcvField(),
			),
		}, nil

	default:
		return handleMissingf(nil, "unknown named type: %q", typeName)
	}
}

// fbNameMap lists the Go semantic graph fields whose FlatBuffers accessor has a different
// name. Keys are {<struct-name>, <field-name>} on the Go side; values are the FlatBuffers
// field names. Fields that are not listed use the same name on both sides.
var fbNameMap = map[[2]string]string{
	{"BuiltinStatement", "ID"}:           "Id",
	{"DurationLiteral", "Values"}:        "Value",
	{"ImportDeclaration", "As"}:          "Alias",
	{"MemberAssignment", "Init"}:         "Init_",
	{"NativeVariableAssignment", "Init"}: "Init_",
}

// toFBName maps a field of a Go semantic graph struct to the name of the corresponding
// FlatBuffers field. Names are usually identical; the exceptions come from fbNameMap.
func toFBName(structName, fieldName string) string {
	fbName, renamed := fbNameMap[[2]string{structName, fieldName}]
	if !renamed {
		return fieldName
	}
	return fbName
}

// fbConcat joins two identifiers the way flatc's Go generator derives accessor names.
// The names are normally just concatenated. When s1 ends with "_", the first byte of s2
// is lowercased, e.g. fbConcat("Init_", "Type") == "Init_type". An empty s2 yields s1.
func fbConcat(s1, s2 string) string {
	switch {
	case s2 == "":
		return s1
	case !strings.HasSuffix(s1, "_"):
		return s1 + s2
	default:
		return s1 + strings.ToLower(s2[:1]) + s2[1:]
	}
}

// optionalExprs is the set of {<struct-name>, <field-name>} pairs for Go semantic graph
// fields of type Expression that are optional. For these fields the generated code calls
// the "...Optional" variant of the conversion helper.
var optionalExprs = map[[2]string]struct{}{
	{"CallExpression", "Pipe"}: {},
}

// isOptionalExpr reports whether the Expression-typed field fieldName of the Go struct
// structName is listed in optionalExprs.
func isOptionalExpr(structName, fieldName string) bool {
	_, optional := optionalExprs[[2]string{structName, fieldName}]
	return optional
}

// doBasic returns the Go code that populates one field of a Go semantic graph struct whose
// type is a Go basic type. Strings are converted with string(...) from the FlatBuffers
// accessor's result; float64, bool, int64 and uint64 are assigned directly. Other basic
// types are reported via handleMissingf.
func doBasic(o types.Object, field *types.Var) ([]jen.Code, error) {
	name := field.Name()
	switch kind := field.Type().(*types.Basic).Name(); kind {
	case "string":
		// rcv.<Field> = string(fb.<Field>())
		return []jen.Code{
			jen.Id("rcv").Dot(name).Op("=").String().Params(jen.Id("fb").Dot(name).Params()),
		}, nil
	case "float64", "bool", "int64", "uint64":
		// rcv.<Field> = fb.<Field>()
		return []jen.Code{
			jen.Id("rcv").Dot(name).Op("=").Id("fb").Dot(name).Params(),
		}, nil
	default:
		// Pass the type name as an argument: handleMissingf treats its second
		// parameter as a printf-style format string.
		return handleMissingf(nil, "unknown basic type: %s", kind)
	}
}

// doSlice returns the Go code that populates one slice-typed field of a Go semantic graph
// struct. Slices of pointers to File, ImportDeclaration, FunctionParameter or Property, and
// slices of the Statement or Expression interfaces, are filled element by element via
// genLoop. Duration and StringExpressionPart slices are converted in one call by the
// handwritten fromFB<Type>Vector helpers. Anything else is reported via handleMissingf.
func doSlice(o types.Object, field *types.Var) ([]jen.Code, error) {
	fieldForError := o.Name() + "." + field.Name()
	elem := field.Type().(*types.Slice).Elem()

	// []*T
	if ptr, isPtr := elem.(*types.Pointer); isPtr {
		pointee := ptr.Elem()
		named, isNamed := pointee.(*types.Named)
		if !isNamed {
			return handleMissingf(nil, "unknown slice elem ptr type: %#v", pointee)
		}
		switch name := named.Obj().Name(); name {
		case "File", "ImportDeclaration", "FunctionParameter", "Property":
			return genLoop(o, name, field), nil
		default:
			return handleMissingf(nil, "unknown slice of ptr to named type: %v", name)
		}
	}

	// []T with T a named type
	named, isNamed := elem.(*types.Named)
	if !isNamed {
		return handleMissingf(nil, "unknown slice elem type: %#v", elem)
	}
	switch name := named.Obj().Name(); name {
	case "Statement", "Expression":
		return genLoop(o, name, field), nil
	case "Duration", "StringExpressionPart":
		call := jen.Id("fromFB" + name + "Vector").Params(jen.Id("fb"))
		return []jen.Code{
			ifErrorPropagate(call, fieldForError, jen.Id("rcv").Dot(field.Name())),
		}, nil
	default:
		return handleMissingf(nil, "unknown slice of named type: %v", name)
	}
}

// genLoop generates the code that populates a slice field of a Go semantic graph struct
// from a FlatBuffers vector, element by element. The generated code has the form:
//
//	if fb.<F>Length() > 0 {
//		rcv.<Field> = make([]<elem>, fb.<F>Length())
//		for i := 0; i < fb.<F>Length(); i++ {
//			fb<T> := new(fbsemantic.<T>)
//			if !fb.<F>(fb<T>, i) {
//				return errors.New(codes.Internal, "could not deserialize <Struct>.<Field>")
//			}
//			<convert fb<T> into rcv.<Field>[i]>
//		}
//	}
//
// <F> is the FlatBuffers field name (see toFBName). When the Go element type is an
// interface, the FlatBuffers element type is "Wrapped<T>" and the conversion calls the
// handwritten from<Wrapped<T>> helper. Otherwise a new(<T>) is allocated and its FromBuf
// method is called. astElemTypeName is the name of the Go element type (without "*").
func genLoop(o types.Object, astElemTypeName string, field *types.Var) []jen.Code {
	fieldForError := o.Name() + "." + field.Name()
	fbFieldName := toFBName(o.Name(), field.Name())
	lenMethodCall := jen.Id("fb").Dot(fbFieldName + "Length").Params()

	elemType := field.Type().(*types.Slice).Elem()
	isInterface := types.IsInterface(elemType)
	_, isPointer := elemType.(*types.Pointer)

	// Go element type used in the make() call: []T or []*T.
	elemTyp := jen.Index()
	if isPointer {
		elemTyp.Add(jen.Op("*"))
	}
	elemTyp.Add(jen.Id(astElemTypeName))

	fbTypeName := astElemTypeName
	if isInterface {
		fbTypeName = "Wrapped" + astElemTypeName
	}
	fbElemVar := jen.Id("fb" + fbTypeName)
	// Qualify with the import path, like the FromBuf signatures do, so that jen records
	// the import and picks the package alias instead of relying on a bare identifier.
	fbElemType := jen.Qual(fbsemPath, fbTypeName)

	// rcvElem returns a fresh rcv.<Field>[i] statement; chained jen calls extend a
	// statement in place, so each use needs its own instance.
	rcvElem := func() *jen.Statement {
		return jen.Id("rcv").Dot(field.Name()).Index(jen.Id("i"))
	}

	loopBody := []jen.Code{
		fbElemVar.Clone().Op(":=").New(fbElemType),
		jen.If(jen.Op("!").Id("fb").Dot(fbFieldName).Params(fbElemVar, jen.Id("i"))).Block(
			returnErrorf("could not deserialize %v", fieldForError),
		),
	}
	if isInterface {
		loopBody = append(loopBody,
			ifErrorPropagate(
				jen.Id("from"+fbTypeName).Params(fbElemVar),
				fieldForError,
				rcvElem(),
			),
		)
	} else {
		loopBody = append(loopBody,
			rcvElem().Op("=").New(jen.Id(astElemTypeName)),
			ifErrorPropagate(
				rcvElem().Dot("FromBuf").Params(fbElemVar),
				fieldForError,
			),
		)
	}

	return []jen.Code{
		jen.If(
			lenMethodCall.Clone().Add(jen.Op(">").Lit(0)),
		).Block(
			jen.Id("rcv").Dot(field.Name()).Op("=").
				Make(elemTyp, lenMethodCall),
			jen.For(
				jen.Id("i").Op(":=").Lit(0),
				jen.Id("i").Op("<").Add(lenMethodCall),
				jen.Id("i").Op("++"),
			).Block(
				loopBody...,
			),
		),
	}
}

// doPointer returns the Go code that populates one pointer-typed field of a Go semantic
// graph struct. Pointers to semantic structs are allocated and filled through their FromBuf
// method when the FlatBuffers value is present. CallExpression.Arguments, Regexp, MonoType
// and PolyType are converted by handwritten helpers. Anything else is reported via
// handleMissingf.
func doPointer(o types.Object, field *types.Var) ([]jen.Code, error) {
	structName, fieldName := o.Name(), field.Name()
	fieldForError := structName + "." + fieldName
	pointee := field.Type().(*types.Pointer).Elem()

	named, isNamed := pointee.(*types.Named)
	if !isNamed {
		return handleMissingf(nil, "unknown pointer elem type: %#v", pointee)
	}

	// viaHelper generates: if rcv.<Field>, err = <helper>(<args...>); err != nil { ... }
	viaHelper := func(helper string, args ...jen.Code) []jen.Code {
		return []jen.Code{
			ifErrorPropagate(
				jen.Id(helper).Params(args...),
				fieldForError,
				jen.Id("rcv").Dot(fieldName),
			),
		}
	}

	switch name := named.Obj().Name(); name {
	case "PackageClause", "Identifier", "ObjectExpression",
		"StringLiteral", "MemberExpression", "IdentifierExpression",
		"NativeVariableAssignment", "FunctionParameters", "FunctionBlock":
		// CallExpression's argument in the Rust semantic graph and in the semantic
		// Flatbuffers schema is defined as list of properties. This is different
		// from the Go semantic graph where it is a pointer to an ObjectExpression.
		// This code helps to bridge that difference.
		if structName == "CallExpression" && fieldName == "Arguments" {
			return viaHelper("objectExprFromProperties", jen.Id("fb")), nil
		}
		// if fb<Field> := fb.<FBField>(nil); fb<Field> != nil {
		//     rcv.<Field> = new(<T>)
		//     <rcv.<Field>.FromBuf(fb<Field>), propagating errors>
		// }
		fbVar := "fb" + fieldName
		return []jen.Code{
			jen.If(
				jen.Id(fbVar).Op(":=").Id("fb").Dot(toFBName(structName, fieldName)).Params(jen.Nil()),
				jen.Id(fbVar).Op("!=").Nil(),
			).Block(
				jen.Id("rcv").Dot(fieldName).Op("=").New(jen.Id(name)),
				ifErrorPropagate(
					jen.Id("rcv").Dot(fieldName).Dot("FromBuf").Params(jen.Id(fbVar)),
					fieldForError,
				),
			),
		}, nil

	case "Regexp":
		return viaHelper("fromFBRegexpLiteral", jen.Id("fb").Dot(fieldName).Params()), nil

	case "MonoType":
		return viaHelper("getMonoType", jen.Id("fb")), nil

	case "PolyType":
		return viaHelper("getPolyType", jen.Id("fb")), nil

	default:
		return handleMissingf(nil, "unknown pointer to named type: %#v", name)
	}
}

// ifErrorPropagate generates a Go "if" statement of the form:
//
//	if <assignee1>, <assignee2>, ..., err = <call>; err != nil {
//		return errors.Wrap(err, codes.Inherit, <wrapMsg>)
//	}
//
// When wrapMsg is empty, the error is returned unwrapped ("return err"). err is always
// the last assignee, so call must return an error as its final result.
func ifErrorPropagate(call jen.Code, wrapMsg string, assignees ...jen.Code) jen.Code {
	// Build a new slice so that appending err never writes into the caller's backing array.
	lhs := make([]jen.Code, 0, len(assignees)+1)
	lhs = append(lhs, assignees...)
	lhs = append(lhs, jen.Err())

	ret := jen.Return(jen.Err())
	if wrapMsg != "" {
		// The argument list belongs to errors.Wrap itself, not to the return statement.
		ret = jen.Return(jen.Qual(errorsPath, "Wrap").Params(
			jen.Err(),
			jen.Qual(codesPath, "Inherit"),
			jen.Lit(wrapMsg),
		))
	}
	return jen.If(
		jen.List(lhs...).Op("=").Add(call),
		jen.Err().Op("!=").Nil(),
	).Block(
		ret,
	)
}

// returnErrorf returns generated Go code for a return statement yielding an Internal error:
//
//	return errors.New(codes.Internal, "<msg>")
//
// The message is formatted at generation time from msgf and args, as with fmt.Sprintf,
// and emitted as a string literal.
func returnErrorf(msgf string, args ...interface{}) jen.Code {
	newErr := jen.Qual(errorsPath, "New").Params(
		jen.Qual(codesPath, "Internal"),
		jen.Lit(fmt.Sprintf(msgf, args...)),
	)
	return jen.Return(newErr)
}

// skipNodes names the semantic node types for which no FromBuf method is generated, either
// because they have no FlatBuffers representation or because their deserialization is
// simpler to write by hand.
var skipNodes = map[string]struct{}{
	"Extern":                     {},
	"ExternBlock":                {},
	"ExternalVariableAssignment": {},
	"FunctionBlock":              {},
	"FunctionExpression":         {},
	"FunctionParameter":          {},
	"FunctionParameters":         {},
	"InterpolatedPart":           {},
	"TextPart":                   {},
}

// skipObject reports whether no FromBuf method should be generated for o. It returns
// false only when all of the following hold:
//   - o is a type declaration (like "type foo struct {}"), not a variable, constant or function;
//   - o's type is a named type that implements the Node interface (has a node() method);
//   - o's name is not listed in skipNodes.
func skipObject(o types.Object) bool {
	// A package-level variable or constant whose type is a Node must be skipped: the
	// generator uses o.Name() as the receiver type, which is only valid for type names.
	if _, isTypeDecl := o.(*types.TypeName); !isTypeDecl {
		return true
	}
	named, ok := o.Type().(*types.Named)
	if !ok || !isSemanticNode(named) {
		return true
	}
	_, skip := skipNodes[o.Name()]
	return skip
}

// isSemanticNode reports whether the named type declares a method called "node". This is
// how types implementing the semantic Node interface are recognized.
func isSemanticNode(named *types.Named) bool {
	for i, count := 0, named.NumMethods(); i < count; i++ {
		if named.Method(i).Name() == "node" {
			return true
		}
	}
	return false
}
